{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Text Classification: How to run inference on the endpoint you have created?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import boto3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's put in some example texts. You can put in any text you like, the model will predict whether it's a positive movie review or negative.\n",
    "These examples are taken from SST2 dataset downloaded from [TensorFlow](https://www.tensorflow.org/datasets/catalog/glue#gluesst2). [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0). [Dataset Homepage](https://nlp.stanford.edu/sentiment/index.html). Citations:\n",
    "<sub><sup>   \n",
    "@inproceedings{socher2013recursive,\n",
    "  title={Recursive deep models for semantic compositionality over a sentiment treebank},\n",
    "  author={Socher, Richard and Perelygin, Alex and Wu, Jean and Chuang, Jason and Manning, Christopher D and Ng, Andrew and Potts, Christopher},\n",
    "  booktitle={Proceedings of the 2013 conference on empirical methods in natural language processing},\n",
    "  pages={1631--1642},\n",
    "  year={2013}\n",
    "}\n",
    "@inproceedings{wang2019glue,\n",
    "  title={ {GLUE}: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding},\n",
    "  author={Wang, Alex and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R.},\n",
    "  note={In the Proceedings of ICLR.},\n",
    "  year={2019}\n",
    "}\n",
    "</sup></sub>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "text1 = 'i simply love this product'\n",
    "text2 = 'worst product ever'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Query endpoint that you have created"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Inference:\n",
      "Input text: 'i simply love this product'\n",
      "Model prediction: [-2.11609411, -2.46743321, -1.08920491, 1.62307894, 4.60758972]\n",
      "Model prediction mapped to labels: \u001b[1m5-star\u001b[0m\n",
      "\n",
      "Inference:\n",
      "Input text: 'worst product ever'\n",
      "Model prediction: [3.88741, 1.64652145, -0.904055178, -2.87982726, -1.96883416]\n",
      "Model prediction mapped to labels: \u001b[1m1-star\u001b[0m\n",
      "\n"
     ]
    }
   ],
   "source": [
    "label_map = {0: \"1-star\", 1: \"2-star\", 2: \"3-star\", 3: \"4-star\", 4: \"5-star\"}\n",
    "newline, bold, unbold = '\\n', '\\033[1m', '\\033[0m'\n",
    "def query_endpoint(encoded_text):\n",
    "    endpoint_name = 'jumpstart-tf-tc-bert-en-uncased-l-12-h-768-a-12-2'\n",
    "    client = boto3.client('runtime.sagemaker')\n",
    "    response = client.invoke_endpoint(EndpointName=endpoint_name, ContentType='application/x-text', Body=encoded_text)\n",
    "    model_predictions = json.loads(response['Body'].read())['predictions'][0]\n",
    "    return model_predictions\n",
    "\n",
    "for text in [text1, text2]:\n",
    "    model_predictions = query_endpoint(text.encode('utf-8'))\n",
    "    class_index = model_predictions.index(max(model_predictions))\n",
    "    print (f\"Inference:{newline}\"\n",
    "            f\"Input text: '{text}'{newline}\"\n",
    "            f\"Model prediction: {model_predictions}{newline}\"\n",
    "            f\"Model prediction mapped to labels: {bold}{label_map[class_index]}{unbold}{newline}\")"
   ]
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "Python 3 (Data Science)",
   "language": "python",
   "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
